﻿"""
Includes a base transport class.

@todo Need to determine how much of the functionality we're going to keep/reuse or throw away.
"""


from contextlib import contextmanager
import math

import viz
import vizmat
import vizshape
import vizact
import vizconfig


class Transport(viz.VizNode, vizconfig.Configurable):
	"""
	Base class for all transports.
	
	A transport wraps a node (either supplied by the caller or a default one
	created here) and moves it. Movement calls issued inside deferred() are
	merged and applied by finalize() when the context exits; onUpdate() runs
	the optional update function inside such a context.
	"""
	
	def __init__(self, node=None, pivot=None, debug=False, updatePriority=0):
		"""
		@param node node moved by the transport. When None a default node is
			created (a small red sphere when debug is True, otherwise an empty
			group); remove() disposes of that default node.
		@param pivot optional node the transport rotates around (see setPivot).
		@param debug makes the default node visible; ignored if node is given.
		@param updatePriority priority of the update event that calls
			onUpdate().
		"""
		# Add the node representing the tracker
		self._defaultNode = None
		if node is None:
			if debug:
				self._node = vizshape.addSphere(0.1)
				self._node.color(1, 0, 0)
			else:
				self._node = viz.addGroup()
			self._defaultNode = self._node
		else:
			self._node = node
		viz.VizNode.__init__(self, self._node.id)
		
		self._pivot = pivot
		self._dt = 0
		self.lastBoundPos = self._node.getPosition()
		# movement bounds per axis; None means unbounded (see setMovementBounds)
		self.movementBounds_xmin = None
		self.movementBounds_xmax = None
		self.movementBounds_ymin = None
		self.movementBounds_ymax = None
		self.movementBounds_zmin = None
		self.movementBounds_zmax = None
		
		# finalize() copies _lockRequested into _frameLocked and clears it
		self._frameLocked = 0
		self._lockRequested = 0
		
		# per-axis (x, y, z) movement speed scale, see setMovementSpeed
		self._movementSpeed = vizmat.Vector([1.0]*3)
		
		# True while inside a deferred() context
		self._deferred = False
		# optional callable, invoked as updateFunction(self) by onUpdate()
		self._updateFunction = None
		self.updateEvent = vizact.onupdate(updatePriority, self.onUpdate)
		
		# Define a name attribute
		self._name = '{}({})'.format(self.__class__.__name__, self.id)
		
		# Init the base configurable object with above name
		vizconfig.Configurable.__init__(self, name=self._name)
	
	def createConfigUI(self):
		"""Creates the vizconfig configuration ui.
		
		You do not need to call this function directly.
		
		@return vizconfig.DefaultUI() with an 'Enabled' toggle bound to the
			transport's update event.
		"""
		ui = vizconfig.DefaultUI()
		ui.addBoolItem('Enabled', self.updateEvent.setEnabled, self.updateEvent.getEnabled)
		
		return ui
	
	def bindMovement(self, pos, sticky=True):
		"""Binds the movement of the transport into the region given by the
		movement bounds (see setMovementBounds). Bounds set to None are ignored.
		
		@param pos proposed position [x, y, z].
		@param sticky when True, a position outside any bound is rejected and
			the node's current position is returned instead, so motion stops
			on every axis. When False, each axis is clamped independently, so
			motion continues along the axes still inside their bounds; in this
			mode pos itself is clamped in place and returned.
		@return the bound position.
		"""
		self.lastBoundPos = self._node.getPosition()
		resultPos = pos
		
		boundsByAxis = (
			(0, self.movementBounds_xmin, self.movementBounds_xmax),
			(1, self.movementBounds_ymin, self.movementBounds_ymax),
			(2, self.movementBounds_zmin, self.movementBounds_zmax),
		)
		
		if sticky:
			# sticky -> if stopped in x, z won't update
			for axis, lower, upper in boundsByAxis:
				belowMin = lower is not None and lower > pos[axis]
				aboveMax = upper is not None and upper < pos[axis]
				if belowMin or aboveMax:
					resultPos = self.lastBoundPos
		else:
			# slippery -> if stopped in x, z will still update
			for axis, lower, upper in boundsByAxis:
				if lower is not None:
					pos[axis] = max(lower, pos[axis])
				if upper is not None:
					pos[axis] = min(upper, pos[axis])
		
		self.lastBoundPos = resultPos
		return resultPos
	
	@contextmanager
	def deferred(self):
		"""Sets the deferred state of the transport, so multiple method
		calls can be queued at the same time. On normal exit finalize() is
		called to apply the queued calls; the previous deferred state is
		restored even if the body (or finalize) raises.
		"""
		oldDeferred = self._deferred
		self._deferred = True
		try:
			yield
			self.finalize()
		finally:
			# without this the transport would stay deferred after an exception
			self._deferred = oldDeferred
	
	def finalize(self):
		"""Base implementation to finalize queued items, merge operations to
		obtain a final state. This base implementation copies the lock
		request into the frame locked state and clears the request.
		"""
		self._frameLocked = self._lockRequested
		self._lockRequested = 0
	
	def getNode(self):
		"""Returns the node for the transport.
		@return viz.VizNode()
		"""
		return self._node
	
	def getPivot(self):
		"""Returns the node around which the transport rotates.
		@return viz.VizNode()
		"""
		return self._pivot
	
	def getMovementSpeed(self):
		"""Returns the movement speed of the transport.
		@return vizmat.Vector() of per-axis speeds; this is the internal
			vector, not a copy.
		"""
		return self._movementSpeed
	
	def onUpdate(self):
		"""Update callback: runs the update function (if any) inside a
		deferred() context so its movement calls are finalized together."""
		if self._updateFunction:
			with self.deferred():
				self._updateFunction(self)
	
	def remove(self):
		"""Removes the transport and any default added nodes.
		
		Safe to call more than once.
		"""
		if self._defaultNode:
			self._defaultNode.remove()
			self._defaultNode = None
		if self.updateEvent is not None:
			self.updateEvent.setEnabled(False)
			self.updateEvent.remove()
			self.updateEvent = None
	
	def setMovementBounds(self, xmin=None, xmax=None, ymin=None, ymax=None, zmin=None, zmax=None):
		"""Sets the movement bounds of the transport used by bindMovement.
		Pass None for an axis limit to leave it unbounded."""
		self.movementBounds_xmin = xmin
		self.movementBounds_xmax = xmax
		self.movementBounds_ymin = ymin
		self.movementBounds_ymax = ymax
		self.movementBounds_zmin = zmin
		self.movementBounds_zmax = zmax
	
	def setMovementSpeed(self, speed, affectVertical=False):
		"""Sets the movement speed of the transport.
		
		@param speed either a list of per-axis speeds [x, y, z] or a single
			number.
		@param affectVertical only used when speed is a single number: when
			True all three axes are set, otherwise only x and z are changed (in
			place) and the y speed is kept.
		"""
		if viz.islist(speed):
			self._movementSpeed = vizmat.Vector(speed)
		else:
			if affectVertical:
				self._movementSpeed = vizmat.Vector([speed]*3)
			else:
				self._movementSpeed[0] = speed
				self._movementSpeed[2] = speed
	
	def setNode(self, node):
		"""Sets the node used by the transport.
		
		The new node takes over the global position and orientation and the
		first parent of the current node, and all children of the current
		node are re-parented to it. The replaced node itself is not removed.
		"""
		if node == self._node:
			return
		pos = self.getPosition(viz.ABS_GLOBAL)
		quat = self.getQuat(viz.ABS_GLOBAL)
		if self.getParents():
			node.setParent(self.getParents()[0])
		node.setPosition(pos, viz.ABS_GLOBAL)
		node.setQuat(quat, viz.ABS_GLOBAL)
		while self.getChildren():
			self.getChildren()[0].setParent(node)
		self._node = node
		self.id = node.id
	
	def setPivot(self, pivot):
		"""Sets the pivot used by the transport. The pivot is the node the transport
		rotates around. So if a transport rotates the pivots position should remain
		constant.
		"""
		self._pivot = pivot
	
	def setUpdateFunction(self, updateFunction):
		"""Sets the function called by the internal update. This function
		should call any move forward, move backward, turn left, etc calls
		needed on a per frame basis. It is called as updateFunction(transport).
		"""
		self._updateFunction = updateFunction
	
	# setVelocity / setAngularVelocity are implemented by viz.VizNode for
	# physics-based objects; subclasses for non physics-based objects should
	# implement them themselves.
	
	def _adjustForPivot(self, pos, euler):
		"""Applies pos and euler to the transport, offsetting pos so that the
		pivot's position does not move because of the orientation change.
		"""
		# get the position of the pivot
		parentMat = self.getMatrix(viz.ABS_GLOBAL)
		parentMat.preMult(self.getMatrix(viz.ABS_PARENT).inverse())
		parentMat = parentMat.inverse()
		# remove the parent transform from the points
		n = self.getPosition(viz.ABS_GLOBAL)
		p1 = vizmat.Vector(self._pivot.getPosition(viz.ABS_GLOBAL))
		n = parentMat.preMultVec(n)
		p1 = vizmat.Vector(parentMat.preMultVec(p1))
		dp1 = p1 - n
		# get the adjusted position from the rotation change
		mat = vizmat.Transform()
		mat.setQuat(self.getQuat())
		mat.transpose()
		mat.postEuler(euler)
		dp2 = vizmat.Vector(mat.preMultVec(dp1))
		dn = dp2 - dp1
		# remove the difference in the positions from the position update of the transport
		self.setPosition(vizmat.Vector(pos)-dn)
		self.setEuler(euler)
	
	def _updateTime(self):
		"""Internal method which updates the delta time from
		viz.getFrameElapsed()."""
		self._dt = viz.getFrameElapsed()
	
	def __repr__(self):
		"""String representation of class instances"""
		return self._name


class AccelerationTransport(Transport):
	"""Base class for any transports using acceleration.
	
	Subclasses request motion by writing accelerations (magnitude 0-1) into
	self._Ap (translation, expressed in the frame returned by
	getAccelerationRotationMatrix) and self._Ar (rotation). finalize() scales
	and integrates them into the velocities self._Vp / self._Vr, applies drag
	and auto breaking, moves the node and then clears the requests.
	"""

	def __init__(self, pivot=None, debug=False,
					acceleration=4.0,
					maxSpeed=10.44,
					rotationAcceleration=180.0,
					maxRotationSpeed=90.0,
					autoBreakingDragCoef=0.1,
					dragCoef=0.0001,
					rotationAutoBreakingDragCoef=0.2,
					rotationDragCoef=0.0001,
					autoBreakingTimeout=0,
					rotationAutoBreakingTimeout=0,
					**kwargs):
		"""
		@param pivot node the transport rotates around; forwarded to Transport.
		@param debug forwarded to Transport (visible default node).
		@param acceleration in meters per second per second; lower
			accelerations can be obtained by using a smaller magnitude on the
			input, e.g. pressing the joystick less.
		@param maxSpeed in meters per second; as a reference 1.4 m/s is a
			typical walking speed, 10.44 is a very fast run.
		@param rotationAcceleration in degrees per second per second.
		@param maxRotationSpeed in degrees per second.
		@param autoBreakingDragCoef determines how quickly the transport stops
			along an axis it is not accelerating along in the direction of travel.
		@param dragCoef normal drag coefficient for translation.
		@param rotationAutoBreakingDragCoef as autoBreakingDragCoef, for rotation.
		@param rotationDragCoef normal drag coefficient for rotation.
		@param autoBreakingTimeout how long (accumulated frame time) before
			auto breaking is enabled.
		@param rotationAutoBreakingTimeout how long before rotational auto
			breaking is enabled.
		@param kwargs forwarded to Transport (e.g. node, updatePriority).
		"""
		# init the base class; pivot and debug must be forwarded explicitly
		# because naming them in this signature removes them from kwargs
		super(AccelerationTransport, self).__init__(pivot=pivot, debug=debug, **kwargs)
		
		self._exp = 1# defines the curvature from the input
		
		self._Vp = vizmat.Vector([0, 0, 0])# note Vp is in global not local coordinates
		self._Vr = vizmat.Vector([0, 0, 0])
		
		self._Ap = vizmat.Vector([0, 0, 0])# note Ap is in local not global coordinates
		self._Ar = vizmat.Vector([0, 0, 0])
		
		self._acceleration = acceleration
		self._maxSpeed = maxSpeed
		self._rotationAcceleration = rotationAcceleration
		self._maxRotationSpeed = maxRotationSpeed
		
		self._autoBreakingDragCoef = autoBreakingDragCoef
		self._dragCoef = dragCoef
		self._rotationAutoBreakingDragCoef = rotationAutoBreakingDragCoef
		self._rotationDragCoef = rotationDragCoef
		
		self._autoBreakingTimeout = autoBreakingTimeout
		self._rotationAutoBreakingTimeout = rotationAutoBreakingTimeout
		# per-axis time spent with breaking pending
		self._autoBreakingTimeoutCounter = [0, 0, 0]
		self._rotationAutoBreakingTimeoutCounter = [0, 0, 0]
	
	def setVelocity(self, vel):
		"""Sets the (global) velocity"""
		self._Vp = vizmat.Vector(vel)
	
	def getVelocity(self):
		"""Returns the (global) velocity"""
		return self._Vp
	
	def setAngularVelocity(self, vel):
		"""Sets the angular velocity"""
		self._Vr = vizmat.Vector(vel)
	
	def getAccelerationRotationMatrix(self):
		"""Returns the acceleration rotation matrix, i.e. the matrix used to
		determine reference frame for acceleration. Subclasses should
		override this; the base version logs a warning and returns identity.
		
		@return vizmat.Transform()
		"""
		viz.logWarn('*** Warning: transport does not implement getAccelerationRotationMatrix.')
		return vizmat.Transform()
	
	def setRotationAcceleration(self, rotationAcceleration):
		"""Sets the rotation acceleration (degrees per second per second)."""
		self._rotationAcceleration = rotationAcceleration
	
	def getRotationAcceleration(self):
		"""Returns the rotation acceleration (degrees per second per second)."""
		return self._rotationAcceleration
	
	def setMaxRotationSpeed(self, maxRotationSpeed):
		"""Sets the maximum rotation speed (degrees per second)."""
		self._maxRotationSpeed = maxRotationSpeed
	
	def getMaxRotationSpeed(self):
		"""Returns the maximum rotation speed (degrees per second)."""
		return self._maxRotationSpeed

	def finalize(self):
		"""Method which executes the queued functions such as
		moveForward and moveBack basing them off the sample orientation
		from the tracker. Should be called regularly either by a timer
		or ideally at every frame.
		"""
		super(AccelerationTransport, self).finalize()
		self._updateTime()
		# inverse frame time, capped at 60; a zero frame time (e.g. on the
		# very first frame) would otherwise raise ZeroDivisionError
		if self._dt > 0:
			idt = min(60.0, 1.0/self._dt)
		else:
			idt = 60.0
		
		# if necessary normalize the acceleration
		mag = self._Ap.length()
		if mag > 1.0:
			self._Ap = self._Ap / mag
		# .. and for rotation
		mag = self._Ar.length()
		if mag > 1.0:
			self._Ar = self._Ar / mag
		
		# scale acceleration (right now no units just 0-1 range magnitude vector)
		self._Ap *= self._acceleration
		# .. and for rotation
		self._Ar *= self._rotationAcceleration
		
		# get the current position
		pos = self.getPosition()
		euler = self.getEuler()
		
		rotMat = self.getAccelerationRotationMatrix()
		invMat = rotMat.inverse()
		
		# velocity expressed in the acceleration frame (computed once)
		localVp = invMat.preMultVec(self._Vp)
		
		# we want to have a fast deceleration if we're not moving in a particular direction
		breakingVec = vizmat.Vector(localVp) * self._autoBreakingDragCoef * idt
		for i in range(3):
			# no breaking while accelerating along the current direction of travel
			if self._Ap[i] != 0 and (self._Ap[i]*localVp[i] > 0):
				breakingVec[i] = 0
			if breakingVec[i]:
				if self._autoBreakingTimeoutCounter[i] < self._autoBreakingTimeout:
					breakingVec[i] = 0# cancel breaking until the timeout elapses
				self._autoBreakingTimeoutCounter[i] += self._dt
			else:
				self._autoBreakingTimeoutCounter[i] = 0
		breakingVec = rotMat.preMultVec(breakingVec)
		
		# now apply the acceleration to the velocity
		drag = self._Vp * self._dragCoef * idt
		adjAp = rotMat.preMultVec(self._Ap)
		for i in range(3):
			self._Vp[i] += (adjAp[i] - drag[i] - breakingVec[i]) * self._dt
		velMag = self._Vp.length()
		if velMag > self._maxSpeed:
			self._Vp = (self._Vp / velMag) * self._maxSpeed
		
		# .. and for rotation
		breakingVec = self._Vr * self._rotationAutoBreakingDragCoef * idt
		for i in range(3):
			if self._Ar[i] != 0:
				breakingVec[i] = 0
			if breakingVec[i]:
				if self._rotationAutoBreakingTimeoutCounter[i] < self._rotationAutoBreakingTimeout:
					breakingVec[i] = 0# cancel breaking until the timeout elapses
				self._rotationAutoBreakingTimeoutCounter[i] += self._dt
			else:
				self._rotationAutoBreakingTimeoutCounter[i] = 0
		
		drag = self._Vr * self._rotationDragCoef * idt
		# NOTE(review): only indices 0 and 1 (presumably yaw and pitch) are
		# integrated; the third component of _Vr is never accelerated, dragged
		# or braked here, although it is still applied to the orientation below.
		for i in range(2):
			self._Vr[i] += (self._Ar[i] - drag[i] - breakingVec[i]) * self._dt
		velMag = self._Vr.length()
		if velMag > self._maxRotationSpeed:
			self._Vr = (self._Vr / velMag) * self._maxRotationSpeed
		
		# displacement for this frame, scaled per axis in the acceleration frame
		dp = invMat.preMultVec([self._Vp[i] * self._dt for i in range(3)])
		for i in range(3):
			dp[i] *= self._movementSpeed[i]
		dp = rotMat.preMultVec(dp)
		
		# TODO, not the best implementation if we want to allow distinct x/z scales fix when time allows
		# apply the velocity to the position
		for i in range(3):
			pos[i] += dp[i]
		# .. and for rotation
		for i in range(3):
			euler[i] += self._Vr[i] * self._dt
		
		euler[1] = max(-89.9, min(89.9, euler[1]))
		
		# apply the final position update, adjusting for the pivot if necessary
		if self._pivot is None:
			self.setPosition(pos)
			self.setEuler(euler)
		else:
			self._adjustForPivot(pos, euler)
		
		# requested accelerations only last for a single finalize()
		for i in range(3):
			self._Ap[i] = 0
			self._Ar[i] = 0


class PhysicsTransport(Transport):
	"""Base class for any transport using physics.
	
	_initPhysics() creates (or adopts) a physics node and links it to the
	transport node. Two update events, scheduled just before and just after
	the physics priority, push manual changes into the physics node and read
	the resulting velocity back.
	"""
	
	def __init__(self, **kwargs):
		super(PhysicsTransport, self).__init__(**kwargs)
		self._Vp = vizmat.Vector([0, 0, 0])
		# orientation re-applied to the physics node after each physics step
		self._physManualEuler = [0, 0, 0]
		self._physNode = None
		self._physLink = None
		self._prePhysicsUpdateEvent = None
		self._postPhysicsUpdateEvent = None
		# physics node created by _initPhysics() (None if one was supplied)
		self._defaultPhysicsNode = None
		self._postPhysMat = self.getMatrix()
	
	def getPhysicsLink(self):
		"""Returns the link between the physics node and the transport
		
		@return viz.VizLink()
		"""
		return self._physLink
	
	def getPhysicsNode(self):
		"""Returns the node used for physics. The transport node
		is linked to this object. Only applies if physics is enabled.
		
		@return viz.VizNode()
		"""
		return self._physNode
	
	def setPhysicsNode(self, node):
		"""Sets the node used for physics and re-targets the physics link to
		it. If the node being replaced is the default physics node created by
		_initPhysics(), that node is removed.
		
		Requires physics to have been initialized (see _initPhysics).
		"""
		oldNode = self._physNode
		self._physNode = node
		self._physLink.setSrc(node)
		# remove old default node, it is no longer used
		if (self._defaultPhysicsNode is not None
				and oldNode is self._defaultPhysicsNode
				and node.id != oldNode.id):
			self._defaultPhysicsNode.remove()
			self._defaultPhysicsNode = None
	
	def setPhysicsEnabled(self, state):
		"""Sets the state of physics usage for the transport.
		
		@param state True, False or viz.TOGGLE.
		"""
		enabled = state
		if state == viz.TOGGLE:
			enabled = not self._prePhysicsUpdateEvent.getEnabled()
		
		self._physLink.setEnabled(enabled)
		self._prePhysicsUpdateEvent.setEnabled(enabled)
		self._postPhysicsUpdateEvent.setEnabled(enabled)
		# NOTE(review): DYNAMICS is toggled on the transport node, not on the
		# physics node that carries the collide shapes; confirm this is intended.
		if enabled:
			self._node.enable(viz.DYNAMICS)
		else:
			self._node.disable(viz.DYNAMICS)
	
	def remove(self):
		"""Removes the physics resources and then the transport itself."""
		# _removePhysics is a no-op for resources that were never created
		self._removePhysics()
		super(PhysicsTransport, self).remove()
	
	def _initPhysics(self, physNode=None):
		"""Initializes the state and resources for physics for the transport.
		
		@param physNode node to use for physics; when None a default one is
			created via _initPhyiscsNode() and removed by _removePhysics().
		"""
		self._physNode = physNode
		
		# if no physics node is supplied, create one and remember it so that
		# _removePhysics() can dispose of it. _initPhyiscsNode assigns
		# self._physNode (overrides may not return the node), so read it back.
		self._defaultPhysicsNode = None
		if self._physNode is None:
			self._initPhyiscsNode()
			self._defaultPhysicsNode = self._physNode
		
		# orientation for physics-based movement still computed manually
		self._physManualEuler = [0, 0, 0]
		
		self._prePhysicsUpdateEvent = vizact.onupdate(viz.PRIORITY_PHYSICS-1, self._prePhysics)
		self._postPhysicsUpdateEvent = vizact.onupdate(viz.PRIORITY_PHYSICS+1, self._postPhysics)
		
		# add a matrix to monitor forced changes to the transform
		self._postPhysMat = vizmat.Transform()
		
		# add a link to the transport
		self._physLink = viz.link(self._physNode, self._node)
	
	def _initPhyiscsNode(self):
		"""Internal method which creates a physics node to represent the
		transport. Most likely should be overwritten by a child class;
		overrides must assign self._physNode.
		
		@return the created physics node (also stored in self._physNode).
		"""
		physNode = viz.addGroup()
		physNode.enable(viz.COLLIDE_NOTIFY)
		# add a collide sphere to the default physics node
		physicsHeightOffset = 0.5
		pos = self._node.getPosition()
		pos[1] += physicsHeightOffset*2.0
		physNode.setPosition(pos)
		collideSphere = physNode.collideSphere(radius=physicsHeightOffset)
		collideSphere.setBounce(0.0001)
		collideSphere.setFriction(0.00)
		self._physNode = physNode
		return physNode
	
	def _prePhysics(self):
		"""Update that happens before physics is run."""
		# Get the matrix of the node; if it's not the same as the post physics
		# matrix, the transform was changed outside physics, so update the
		# manual euler and push the new matrix into the physics node.
		mat = self.getMatrix()
		if mat != self._postPhysMat:
			self._physManualEuler = self.getEuler()
			self._physNode.setMatrix(mat)
	
	def _postPhysics(self):
		"""Update that happens after physics is run: restores the manually
		computed orientation, reads back the velocity and cancels any angular
		velocity produced by the physics step."""
		self._physNode.setEuler(self._physManualEuler)
		self._Vp = vizmat.Vector(self._physNode.getVelocity())
		self._physNode.setAngularVelocity([0, 0, 0])
		# save matrix to compare with pre-physics matrix of the node on next frame
		self._postPhysMat = self.getMatrix()
	
	def _removePhysics(self):
		"""Removes physics from the node. Safe to call more than once."""
		if self._prePhysicsUpdateEvent is not None:
			self._prePhysicsUpdateEvent.remove()
			self._prePhysicsUpdateEvent = None
		if self._postPhysicsUpdateEvent is not None:
			self._postPhysicsUpdateEvent.remove()
			self._postPhysicsUpdateEvent = None
		# drop the link before removing the node it reads from
		if self._physLink is not None:
			self._physLink.remove()
			self._physLink = None
		if self._defaultPhysicsNode is not None:
			if self._physNode is self._defaultPhysicsNode:
				self._physNode = None
			self._defaultPhysicsNode.remove()
			self._defaultPhysicsNode = None
